package com.simple.rpc.common.network;

import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;

/**
 * Project: simple-rpc
 *
 * <p>Description: a custom thread-local implementation. In addition to the storage provided by
 * {@link InheritableThreadLocal}, every value written through {@link #set(Object)} is mirrored
 * into a static per-thread cache. That cache can be snapshotted on one thread with
 * {@link #copy()} and installed on another thread with {@link #initChildThreadLocal(Map)}.
 *
 * <p>The per-thread cache is keyed by the {@link Thread} object itself and holds a strong
 * reference to it, so a thread that installs or writes values should eventually call
 * {@link #removeChildThreadLocal()} (or {@link #remove()} on every variable it set).
 *
 * <p>Created: 2023-08-18 17:16
 *
 * @param <T> type of the value held by this thread-local
 * @author WuChengXing
 **/
public class SimpleThreadLocal<T> extends InheritableThreadLocal<T> {

    /**
     * Per-thread value cache: thread -> (thread-local instance -> value).
     * The outer map is touched by many threads, hence the concurrent implementation; each inner
     * map is only read and written by this class on behalf of the thread that owns the entry.
     */
    private static final Map<Thread, Map<SimpleThreadLocal<?>, Object>> CACHE = new ConcurrentHashMap<>();

    public SimpleThreadLocal() {
    }

    /**
     * Takes a snapshot of all values cached for the calling thread.
     *
     * <p>The returned map is a fresh copy, so later {@code set}/{@code remove} calls on the
     * calling thread do not affect it, and a thread that installs it cannot modify the caller's
     * values.
     *
     * @return a copy of the calling thread's cached values, or {@code null} if the calling
     *         thread has no cache entry
     */
    public static Map<SimpleThreadLocal<?>, Object> copy() {
        Map<SimpleThreadLocal<?>, Object> values = CACHE.get(Thread.currentThread());
        return values == null ? null : new HashMap<>(values);
    }

    /**
     * Installs the given values as the calling thread's cache; intended to be called before the
     * thread's task runs.
     *
     * <p>The map is stored as-is (not copied) and will be mutated by later {@code set}/
     * {@code remove} calls on this thread, so it should not be handed to more than one thread.
     *
     * @param cache values to install, typically the result of {@link #copy()} on another thread;
     *              {@code null} clears the calling thread's cache entry
     */
    public static void initChildThreadLocal(Map<SimpleThreadLocal<?>, Object> cache) {
        Thread current = Thread.currentThread();
        if (cache == null) {
            // ConcurrentHashMap rejects null values; "no cache" is represented by an absent entry.
            CACHE.remove(current);
            return;
        }
        CACHE.put(current, cache);
    }

    /**
     * Removes every cached thread-local value associated with the calling thread.
     */
    public static void removeChildThreadLocal() {
        CACHE.remove(Thread.currentThread());
    }

    /**
     * Returns the calling thread's value for this thread-local.
     *
     * @return the non-null value cached for the calling thread if there is one; otherwise the
     *         value held by the underlying {@link InheritableThreadLocal}
     */
    @Override
    public T get() {
        Map<SimpleThreadLocal<?>, Object> values = CACHE.get(Thread.currentThread());
        if (values != null) {
            // Safe as long as the entry was written by set(T) on this instance or comes from a
            // snapshot of such entries.
            @SuppressWarnings("unchecked")
            T cached = (T) values.get(this);
            if (cached != null) {
                return cached;
            }
        }
        // Only consult the inherited storage when the cache has no value, so that the superclass
        // is not asked to initialise a slot needlessly.
        return super.get();
    }

    /**
     * Stores the value for the calling thread both in the cache and in the underlying
     * {@link InheritableThreadLocal}.
     *
     * @param value the value to associate with the calling thread
     */
    @Override
    public void set(T value) {
        CACHE.computeIfAbsent(Thread.currentThread(), k -> new HashMap<>()).put(this, value);
        super.set(value);
    }

    /**
     * Removes the calling thread's value for this thread-local from both the cache and the
     * underlying {@link InheritableThreadLocal}, so that a subsequent {@link #get()} no longer
     * returns the previously set value.
     */
    @Override
    public void remove() {
        Thread current = Thread.currentThread();
        Map<SimpleThreadLocal<?>, Object> values = CACHE.get(current);
        if (values != null) {
            values.remove(this);
            if (values.isEmpty()) {
                // Drop the empty per-thread map so the static cache stops referencing the thread.
                CACHE.remove(current);
            }
        }
        // Without this the value written by set() would still be returned by get().
        super.remove();
    }
}
